#include "MultiSwap.h"


/*
 * mod232inv - multiplicative inverse of an odd x modulo 2^32.
 *
 * For odd x, the units modulo 2^32 form a group whose exponent is 2^30.
 * Hence x^(2^31) = 1 and x^(2^31 - 1) = x^-1 (mod 2^32).
 *
 * Each loop iteration maps r = x^(2^k - 1) to r * r * x = x^(2^(k+1) - 1).
 * Starting from r = x (k = 1), 30 iterations reach x^(2^31 - 1).
 *
 * The arithmetic is unsigned so that wraparound modulo 2^32 is well
 * defined; signed overflow would be undefined behaviour.  This assumes
 * unsigned int is 32 bits wide, as the 16-bit half swaps elsewhere in
 * this file also do.
 *
 * An even x has no inverse.  Callers must pass an odd value, e.g. by
 * OR-ing in the low bit.
 */
static unsigned int mod232inv(unsigned int x)
{
	unsigned int rval = x;
	int i;

	for (i = 0; i < 30; i++)
		rval *= rval * x;

	return rval;
}


/*
 * MultiSwapSetKey - expand 12 raw key words into a two-round MULTISWAPKEY.
 *
 * Key word layout:
 *   datain[0..4]   round 0 multipliers   datain[5]   round 0 additive
 *   datain[6..10]  round 1 multipliers   datain[11]  round 1 additive
 *
 * Each multiplier has its low bit forced to 1 so that it is odd.  Its
 * inverse modulo 2^32 (used by MultiSwapDecode) is stored alongside it.
 *
 * The additive constants are also forced odd.  No inverse is needed for
 * them, so this is not required for correctness, but the scheme does it
 * and it is kept for compatibility.
 */
void MultiSwapSetKey(MULTISWAPKEY * out, unsigned int *datain)
{
	int word, rnd;

	for (word = 0; word < 5; word++) {
		for (rnd = 0; rnd < 2; rnd++) {
			/* Each round's words start 6 apart in datain. */
			unsigned int odd = datain[6 * rnd + word] | 1;

			out->round[rnd].multikey[word] = odd;
			out->round[rnd].multiinv[word] = mod232inv(odd);
		}
	}

	for (rnd = 0; rnd < 2; rnd++)
		out->round[rnd].additive = datain[6 * rnd + 5] | 1;
}


/*
 * MultiSwapEncode - push one two-word input block through both rounds.
 *
 * Each round works the same way:
 *   1. Add an input word to a running 32-bit accumulator.
 *   2. Four times: multiply by a key word, then swap the 16-bit halves.
 *   3. Finish with a multiply by key word 4 plus the round's additive.
 * Round 1 starts from round 0's result.
 *
 * key:     expanded key (see MultiSwapSetKey)
 * state:   two-word chaining value (read only)
 * datain:  two input words
 * dataout: receives the new two-word value.
 *          dataout[0] = round-1 output.
 *          dataout[1] = round-0 output + round-1 output + state[1].
 *
 * NOTE(review): dataout[1] is written before datain[1] is read.  Calling
 * this with dataout == datain would feed the overwritten word into
 * round 1, so avoid in-place use.
 */
void MultiSwapEncode(MULTISWAPKEY * key, unsigned int *state,
		     unsigned int *datain, unsigned int *dataout)
{
	unsigned int acc = state[0] + datain[0];
	int step;

	/* Round 0 */
	for (step = 0; step < 4; step++) {
		acc *= key->round[0].multikey[step];
		acc = (acc << 16) | (acc >> 16);	/* swap halves */
	}
	acc = acc * key->round[0].multikey[4] + key->round[0].additive;
	dataout[1] = state[1] + acc;

	/* Round 1, chained on the round-0 result */
	acc += datain[1];
	for (step = 0; step < 4; step++) {
		acc *= key->round[1].multikey[step];
		acc = (acc << 16) | (acc >> 16);	/* swap halves */
	}
	acc = acc * key->round[1].multikey[4] + key->round[1].additive;
	dataout[1] += acc;

	dataout[0] = acc;
}


/*
 * MultiSwapMAC - chain MultiSwapEncode over num64bits two-word blocks.
 *
 * The chaining value starts at {0, 0}.  Each block is encoded with the
 * previous chaining value as its state, and the result becomes the next
 * chaining value.  The final value is written to out[0..1].  If num64bits
 * is zero or negative, out receives {0, 0}.
 *
 * data must hold at least 2 * num64bits words.
 */
void MultiSwapMAC(MULTISWAPKEY * key, unsigned int *data, int num64bits,
		  unsigned int *out)
{
	unsigned int chain[2] = { 0, 0 };
	unsigned int prev[2];
	unsigned int *blk = data;
	int left;

	for (left = num64bits; left > 0; left--, blk += 2) {
		/* Encode from a copy so input state and output never alias. */
		prev[0] = chain[0];
		prev[1] = chain[1];
		MultiSwapEncode(key, prev, blk, chain);
	}

	out[0] = chain[0];
	out[1] = chain[1];
}


/*
 * MultiSwapDecode - invert MultiSwapEncode for one two-word block.
 *
 * Given the encoded pair in datain and the same state that was used for
 * encoding, recover the original input words into dataout.  Each round is
 * undone in reverse:
 *   1. Subtract the additive, then multiply by inverse key word 4.
 *   2. Then, for key words 3..0: swap the 16-bit halves and multiply by
 *      that word's inverse.
 *
 * NOTE(review): dataout[1] is written and later re-read together with
 * state[1], so dataout must not alias state.
 */
void MultiSwapDecode(MULTISWAPKEY * key, unsigned int *state,
		     unsigned int *datain, unsigned int *dataout)
{
	unsigned int acc = datain[0];
	unsigned int r0out;
	int step;

	/* datain[1] - round-1 output = round-0 output + state[1] */
	dataout[1] = datain[1] - acc;

	/* Undo round 1: acc becomes round-0 output + original datain[1] */
	acc = (acc - key->round[1].additive) * key->round[1].multiinv[4];
	for (step = 3; step >= 0; step--) {
		acc = (acc << 16) | (acc >> 16);	/* swap halves */
		acc *= key->round[1].multiinv[step];
	}
	r0out = dataout[1] - state[1];
	dataout[1] = acc - r0out;

	/* Undo round 0: acc becomes state[0] + original datain[0] */
	acc = (r0out - key->round[0].additive) * key->round[0].multiinv[4];
	for (step = 3; step >= 0; step--) {
		acc = (acc << 16) | (acc >> 16);	/* swap halves */
		acc *= key->round[0].multiinv[step];
	}

	dataout[0] = acc - state[0];
}
